import java.util.*;

/**
 * LeetCode 1171: remove consecutive sequences of nodes that sum to zero from a singly linked list.
 *
 * <p>Relies on a {@code ListNode} type declared elsewhere in the project. This class uses its
 * {@code val} and {@code next} fields and its no-arg, one-arg and two-arg constructors.
 */
public class Solution1171 {

    /**
     * Maps each running prefix sum to the LAST node (in list order) at which that sum occurs.
     *
     * <p>The running sum includes {@code start.val}. A later node with the same sum overwrites an
     * earlier one, so each key ends up holding the furthest node that reaches that sum.
     *
     * @param start first node to include in the running sum; may be {@code null}
     * @return prefix sum -> last node whose running sum equals it
     */
    private Map<Integer, ListNode> getLastNodeByPrefix(ListNode start) {
        Map<Integer, ListNode> lastNodeByPrefix = new HashMap<>();
        int prefix = 0;
        for (ListNode node = start; node != null; node = node.next) {
            prefix += node.val;
            lastNodeByPrefix.put(prefix, node);
        }
        return lastNodeByPrefix;
    }

    /**
     * Removes consecutive runs of nodes whose values sum to zero, so that no consecutive
     * sequence in the returned list sums to zero.
     *
     * <p>If two nodes have the same prefix sum, the nodes after the first one up to and
     * including the second one sum to zero. Two passes therefore suffice (O(n) time, O(n) extra
     * space):
     * <ol>
     *   <li>record, for every prefix sum, the last node that reaches it;</li>
     *   <li>walk the list again and link each surviving node directly past the last node
     *       sharing its prefix sum.</li>
     * </ol>
     * A sentinel node with value 0 in front of {@code head} lets runs that start at the head
     * be removed as well.
     *
     * <p>The list is relinked in place. When several removals are possible, this returns one
     * valid result.
     *
     * @param head head of the input list; may be {@code null}
     * @return head of the resulting list, or {@code null} if every node was removed
     */
    public ListNode removeZeroSumSublists(ListNode head) {
        ListNode preHead = new ListNode(0, head);
        Map<Integer, ListNode> lastNodeByPrefix = getLastNodeByPrefix(preHead);

        int prefix = 0;
        for (ListNode node = preHead; node != null; node = node.next) {
            // Skipped segments sum to zero, so the running sum along the relinked chain equals
            // this node's prefix sum in the original list, and the lookup always finds a node
            // (possibly this node itself, in which case next is left unchanged).
            prefix += node.val;
            node.next = lastNodeByPrefix.get(prefix).next;
        }
        return preHead.next;
    }

    /**
     * Builds a new singly linked list holding the values of {@code integerList} in order.
     *
     * @param integerList values for the nodes, in list order
     * @return head of the new list, or {@code null} if {@code integerList} is empty
     */
    private ListNode buildAndGetHead(List<Integer> integerList) {
        ListNode dummy = new ListNode();
        ListNode tail = dummy;
        for (Integer value : integerList) {
            tail.next = new ListNode(value);
            tail = tail.next;
        }
        return dummy.next;
    }

    /** Small manual check: builds a sample list, removes zero-sum runs and prints the result. */
    public static void main(String[] args) {
        Solution1171 s = new Solution1171();
        ListNode head = s.buildAndGetHead(Arrays.asList(1,3,2,-3,-2,5,5,-5,1));
        ListNode res = s.removeZeroSumSublists(head);
        // NOTE(review): output format depends on ListNode#toString, which is defined elsewhere.
        System.out.println(res);
    }
}
